package com.derbysoft.nuke.kafka.manager.infrastructure.kafka;

import com.derbysoft.nuke.kafka.manager.infrastructure.kafka.model.*;
import kafka.api.OffsetFetchRequest;
import kafka.api.OffsetFetchResponse;
import kafka.common.OffsetMetadataAndError;
import kafka.common.TopicAndPartition;
import kafka.coordinator.group.GroupOverview;
import kafka.network.BlockingChannel;
import org.apache.kafka.clients.NodeApiVersions;
import org.apache.kafka.clients.consumer.*;
import org.apache.kafka.common.Node;
import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.serialization.StringDeserializer;
import scala.collection.JavaConversions;
import scala.collection.JavaConverters;
import scala.collection.Seq;

import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.stream.Collectors;

public class DefaultKafkaAdminClient implements KafkaAdminClient {

    /** Hard upper bound on how many messages a single-partition fetch may return. */
    private static final int MAX_FETCH_LIMIT = 100;

    /** Correlation id for the hand-rolled OffsetFetchRequest calls in {@link #listGroupOffsets}. */
    private final AtomicInteger correlationId = new AtomicInteger(0);
    /** Unique client id; also seeds the group id of the shared metadata consumer. */
    private final String clientId = UUID.randomUUID().toString();

    private final String bootstrapServers;
    private final kafka.admin.AdminClient adminClient;

    // Shared consumer, used only to fetch metadata (topics, partitions, offsets) - never messages.
    // NOTE(review): KafkaConsumer is not thread-safe; this class appears to assume
    // single-threaded use - confirm with callers.
    private KafkaConsumer<String, String> kafkaConsumer;

    public DefaultKafkaAdminClient(String bootstrapServers) {
        this.bootstrapServers = bootstrapServers;
        this.adminClient = getAdminClient();
        initKafkaConsumer(bootstrapServers);
    }

    /** Builds the long-lived metadata-only consumer (auto-commit off, unique group id). */
    private void initKafkaConsumer(String bootstrapServers) {
        Map<String, Object> conf = new HashMap<>();
        conf.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false);
        conf.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers);
        conf.put(ConsumerConfig.GROUP_ID_CONFIG, "group-" + clientId);
        conf.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
        conf.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
        kafkaConsumer = new KafkaConsumer<>(conf); // diamond instead of raw type
    }

    /** Creates the legacy Scala AdminClient used for broker/group queries. */
    private kafka.admin.AdminClient getAdminClient() {
        Properties properties = new Properties();
        properties.put("bootstrap.servers", bootstrapServers);
        return kafka.admin.AdminClient.create(properties);
    }

    /**
     * Fetches up to {@code limit} messages (capped at {@value #MAX_FETCH_LIMIT}) from one
     * partition starting at {@code startOffset}, using a short-lived throwaway consumer.
     * Best-effort: a single poll may return fewer messages than exist after the offset.
     */
    public List<Message<String, String>> fetchMessages(TopicPartition topicPartition, long startOffset, int limit) {
        limit = Math.min(limit, MAX_FETCH_LIMIT);
        KafkaConsumer<String, String> fetchConsumer = createKafkaConsumer(limit);
        try {
            // Bug fix: the limit parameter was previously ignored; enforce it on the result.
            return cap(doSingleFetchMessage(fetchConsumer, topicPartition, startOffset), limit);
        } finally {
            fetchConsumer.close();
        }
    }

    /**
     * Fetches up to {@code limitPerPartition} messages from each of the given partitions,
     * starting at the per-partition offsets, sharing one throwaway consumer for all fetches.
     */
    @Override
    public List<Message<String, String>> fetchMessages(Map<TopicPartition, Long> offsets, int limitPerPartition) {
        KafkaConsumer<String, String> fetchConsumer = createKafkaConsumer(limitPerPartition);
        try {
            // Bug fix: limitPerPartition was previously ignored; enforce it per partition.
            return offsets.entrySet().stream()
                    .map(e -> cap(doSingleFetchMessage(fetchConsumer, e.getKey(), e.getValue()), limitPerPartition))
                    .flatMap(Collection::stream)
                    .collect(Collectors.toList());
        } finally {
            fetchConsumer.close();
        }
    }

    /** Returns at most the first {@code limit} messages; a non-positive limit means "no cap". */
    private static List<Message<String, String>> cap(List<Message<String, String>> messages, int limit) {
        return (limit > 0 && messages.size() > limit) ? messages.subList(0, limit) : messages;
    }

    /**
     * Assigns the consumer to a single partition, seeks to {@code startOffset} (clamped to 0)
     * and converts whatever one 100ms poll delivers into {@link Message}s.
     */
    private List<Message<String, String>> doSingleFetchMessage(KafkaConsumer<String, String> fetchConsumer, TopicPartition topicPartition, long startOffset) {
        startOffset = Math.max(0, startOffset);
        fetchConsumer.assign(Collections.singleton(topicPartition));
        fetchConsumer.seek(topicPartition, startOffset);
        ConsumerRecords<String, String> consumerRecords = fetchConsumer.poll(100);
        List<ConsumerRecord<String, String>> records = consumerRecords.records(topicPartition);
        return records.stream()
                .map(this::toMessage)
                .collect(Collectors.toList());
    }

    /**
     * Creates a throwaway consumer with a random group id for message fetching.
     * The {@code limit} parameter is deliberately NOT applied here (MAX_POLL_RECORDS was
     * disabled - see commented lines below); callers cap the results themselves.
     */
    private KafkaConsumer<String, String> createKafkaConsumer(int limit) {
        Map<String, Object> conf = new HashMap<>();
        conf.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false);
        conf.put(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG, bootstrapServers);
//        conf.put(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, String.valueOf(limit));
//        conf.put(ConsumerConfig.FETCH_MAX_BYTES_CONFIG, 5 * 1024 * 1024); // 5M, default: 64KB
        conf.put(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
        conf.put(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, StringDeserializer.class.getName());
        conf.put(ConsumerConfig.GROUP_ID_CONFIG, "fetch-message-" + UUID.randomUUID().toString());
        return new KafkaConsumer<>(conf); // diamond instead of raw type
    }

    /** Copies the relevant fields of a ConsumerRecord into the domain Message type. */
    private Message<String, String> toMessage(ConsumerRecord<String, String> consumerRecord) {
        return Message.<String, String>builder()
                .topic(consumerRecord.topic())
                .partition(consumerRecord.partition())
                .offset(consumerRecord.offset())
                .timestamp(consumerRecord.timestamp())
                .key(consumerRecord.key())
                .value(consumerRecord.value())
                .build();
    }

    /** Lists every topic visible to the metadata consumer, with its partition infos. */
    @Override
    public Map<String, List<PartitionInfo>> listTopics() {
        return kafkaConsumer.listTopics();
    }

    /** Returns the partition infos of one topic. */
    @Override
    public List<PartitionInfo> getPartitions(String topic) {
        return kafkaConsumer.partitionsFor(topic);
    }

    /** Returns the [start, end) offset range of a single partition. */
    @Override
    public PartitionOffset getTopicPartitionOffset(TopicPartition topicPartition) {
        return getPartitionOffsets(Collections.singletonList(topicPartition)).get(topicPartition);
    }

    /** Returns the [start, end) offset range of every partition of a topic, keyed by partition id. */
    @Override
    public Map<Integer, PartitionOffset> getTopicOffsets(String topic) {
        List<TopicPartition> topicPartitions = toTopicPartitions(getPartitions(topic));
        return getPartitionOffsets(topicPartitions).entrySet().stream()
                .collect(Collectors.toMap(e -> e.getKey().partition(), e -> e.getValue()));
    }

    /**
     * For each partition, the earliest offset whose timestamp is >= 0, i.e. the first
     * available message; empty partitions map to {@code OffsetTimestamp.empty()}.
     */
    @Override
    public Map<Integer, OffsetTimestamp> getStartOffsetTimestamps(String topic) {
        List<TopicPartition> topicPartitions = toTopicPartitions(getPartitions(topic));
        return topicPartitions.stream()
                .collect(Collectors.toMap(TopicPartition::partition, tp -> getTopicPartitionOffset(tp, 0L).orElse(OffsetTimestamp.empty())));
    }

    /**
     * For each partition, the offset and timestamp of the LAST existing message
     * (end offset - 1, since the end offset is the next offset to be written).
     * Empty partitions map to offset 0 with timestamp 0.
     */
    @Override
    public Map<Integer, OffsetTimestamp> getEndOffsetTimestamps(String topic) {
        List<TopicPartition> topicPartitions = toTopicPartitions(getPartitions(topic));
        Map<TopicPartition, PartitionOffset> partitionOffsets = getPartitionOffsets(topicPartitions);

        Map<Integer, OffsetTimestamp> offsetTimestampMap = new HashMap<>();
        KafkaConsumer<String, String> fetchConsumer = createKafkaConsumer(1);
        try {
            for (TopicPartition topicPartition : topicPartitions) {
                PartitionOffset partitionOffset = partitionOffsets.get(topicPartition);
                long end = partitionOffset.getEnd() > 0 ? partitionOffset.getEnd() - 1 : 0;
                // fetch the last message to read its timestamp
                List<Message<String, String>> messages = doSingleFetchMessage(fetchConsumer, topicPartition, end);
                long timestamp = 0;
                if (!messages.isEmpty()) {
                    timestamp = messages.get(0).getTimestamp();
                }
                offsetTimestampMap.put(topicPartition.partition(), OffsetTimestamp.builder().offset(end).timestamp(timestamp).build());
            }
        } finally {
            fetchConsumer.close();
        }
        return offsetTimestampMap;
    }

    /** Returns the [start, end) offset range of every partition of every visible topic. */
    @Override
    public Map<TopicPartition, PartitionOffset> getAllTopicPartitionOffsets() {
        Map<String, List<PartitionInfo>> topicPartitionInfos = listTopics();
        List<TopicPartition> topicPartitions = topicPartitionInfos.values().stream()
                .flatMap(Collection::stream)
                .map(this::toTopicPartition)
                .collect(Collectors.toList());
        return getPartitionOffsets(topicPartitions);
    }

    private List<TopicPartition> toTopicPartitions(List<PartitionInfo> partitions) {
        return partitions.stream()
                .map(this::toTopicPartition)
                .collect(Collectors.toList());
    }

    private TopicPartition toTopicPartition(PartitionInfo partitionInfo) {
        return new TopicPartition(partitionInfo.topic(), partitionInfo.partition());
    }

    /** Combines beginningOffsets() and endOffsets() of the shared consumer into PartitionOffsets. */
    private Map<TopicPartition, PartitionOffset> getPartitionOffsets(List<TopicPartition> topicPartitions) {
        Map<TopicPartition, Long> beginningOffsets = kafkaConsumer.beginningOffsets(topicPartitions);
        Map<TopicPartition, Long> endOffsets = kafkaConsumer.endOffsets(topicPartitions);
        return topicPartitions.stream()
                .collect(Collectors.toMap(Function.identity(), t -> this.getPartitionOffset(t, beginningOffsets, endOffsets)));
    }

    private PartitionOffset getPartitionOffset(TopicPartition topicPartition, Map<TopicPartition, Long> beginningOffsets, Map<TopicPartition, Long> endOffsets) {
        return PartitionOffset.builder()
                .start(beginningOffsets.getOrDefault(topicPartition, 0L))
                .end(endOffsets.getOrDefault(topicPartition, 0L))
                .build();
    }

    /**
     * Returns the earliest offset whose timestamp is >= {@code time}, or empty when no
     * such message exists (e.g. {@code time} is past the end of the partition).
     */
    @Override
    public Optional<OffsetTimestamp> getTopicPartitionOffset(TopicPartition topicPartition, Long time) {
        Map<TopicPartition, OffsetAndTimestamp> offsets = kafkaConsumer.offsetsForTimes(Collections.singletonMap(topicPartition, time));
        return Optional.ofNullable(toOffsetTimestamp(offsets.get(topicPartition)));
    }

    private OffsetTimestamp toOffsetTimestamp(OffsetAndTimestamp offsetAndTimestamp) {
        if (offsetAndTimestamp == null) return null;
        return OffsetTimestamp.builder()
                .timestamp(offsetAndTimestamp.timestamp())
                .offset(offsetAndTimestamp.offset())
                .build();
    }

    /**
     * Resolves, per partition, the offset range covering messages written between
     * {@code startTime} and {@code endTime} (null start defaults to 0, null end to now).
     */
    @Override
    public Map<Integer, PartitionOffset> searchTopicOffsetsBetween(String topic, Long startTime, Long endTime) {
        final Long searchStartTime = startTime == null ? 0L : startTime;
        final Long searchEndTime = endTime == null ? System.currentTimeMillis() : endTime;
        Map<Integer, PartitionOffset> topicOffsets = getTopicOffsets(topic);
        List<TopicPartition> topicPartitions = toTopicPartitions(getPartitions(topic));
        return topicPartitions.stream()
                .collect(Collectors.toMap(TopicPartition::partition, t -> getRangePartitionOffset(t, searchStartTime, searchEndTime, topicOffsets)));
    }

    private PartitionOffset getRangePartitionOffset(TopicPartition topicPartition, Long startTime, Long endTime, Map<Integer, PartitionOffset> topicOffsets) {
        Optional<OffsetTimestamp> startOffset = getTopicPartitionOffset(topicPartition, startTime);
        Optional<OffsetTimestamp> endOffset = getTopicPartitionOffset(topicPartition, endTime);
        long start = 0L;
        if (startOffset.isPresent()) {
            start = startOffset.get().getOffset();
        }
        long end = 0L;
        if (endOffset.isPresent()) {
            end = endOffset.get().getOffset();
        } else {
            // endTime is past the newest message: fall back to the partition's last offset.
            int partition = topicPartition.partition();
            if (topicOffsets.containsKey(partition)) {
                end = topicOffsets.get(topicPartition.partition()).getEnd();
                if (end > 0) {
                    // The end offset from endOffsets() is the NEXT offset to be written,
                    // not the largest offset that exists, hence minus one.
                    end = end - 1;
                }
            }
        }
        return PartitionOffset.builder()
                .start(start)
                .end(end)
                .build();
    }

    /** Returns the bootstrap broker nodes known to the admin client. */
    @Override
    public List<Node> bootstrapBrokers() {
        return toJavaList(adminClient.bootstrapBrokers());
    }

    /** Returns the supported API versions per broker node. */
    @Override
    public Map<Node, NodeApiVersions> listAllBrokerVersionInfo() {
        return toJavaMap(adminClient.listAllBrokerVersionInfo()).entrySet()
                .stream()
                .collect(Collectors.toMap(Map.Entry::getKey, t -> t.getValue().get()));
    }

    /** Returns, per broker node, the ids of the consumer groups it coordinates. */
    @Override
    public Map<Node, List<String>> listAllConsumerGroups() {
        return toJavaMap(adminClient.listAllConsumerGroups()).entrySet()
                .stream()
                .collect(Collectors.toMap(Map.Entry::getKey, t -> this.groupOverviewToId(toJavaList(t.getValue()))));
    }

    /** Describes one consumer group: coordinator, assignment strategy and member summaries. */
    @Override
    public ConsumerGroupSummary describeConsumerGroup(String groupId) {
        kafka.admin.AdminClient.ConsumerGroupSummary consumerGroupSummary = adminClient.describeConsumerGroup(groupId, 5000);
        Node coordinator = consumerGroupSummary.coordinator();
        String assignmentStrategy = consumerGroupSummary.assignmentStrategy();
        List<ConsumerSummary> consumerSummaries = toJavaList(consumerGroupSummary.consumers().get()).stream()
                .map(this::toConsumerSummary)
                .collect(Collectors.toList());
        return ConsumerGroupSummary.builder()
                .assignmentStrategy(assignmentStrategy)
                .consumerSummaries(consumerSummaries)
                .coordinator(coordinator)
                .build();
    }

    /**
     * Fetches the committed offsets of a group directly from its coordinator via the raw
     * OffsetFetchRequest wire protocol.
     * <p>
     * adminClient.listGroupOffsets(groupId) is not used because older brokers reject it:
     * "The broker only supports OffsetFetchRequest v1, but we need v2 or newer to request
     * all topic partitions."
     */
    @Override
    public Map<TopicPartition, Long> listGroupOffsets(String groupId) {
        ConsumerGroupSummary consumerGroupSummary = describeConsumerGroup(groupId);
        Node coordinator = consumerGroupSummary.getCoordinator();

        BlockingChannel channel = new BlockingChannel(coordinator.host(), coordinator.port(),
                BlockingChannel.UseDefaultBufferSize(),
                BlockingChannel.UseDefaultBufferSize(),
                5000 /* read timeout in millis */);
        channel.connect();
        try {
            short version = 1; // version 1 and above fetch from Kafka, version 0 fetches from ZooKeeper

            OffsetFetchRequest fetchRequest = new OffsetFetchRequest(
                    groupId,
                    topicAndPartitionSeq(consumerGroupSummary),
                    version,
                    correlationId.incrementAndGet(),
                    clientId);
            channel.send(fetchRequest);

            OffsetFetchResponse fetchResponse = OffsetFetchResponse.readFrom(channel.receive().payload(), version);
            Map<TopicAndPartition, OffsetMetadataAndError> metadataAndErrorMap = JavaConverters.mapAsJavaMapConverter(fetchResponse.requestInfo()).asJava();

            return metadataAndErrorMap.entrySet().stream()
                    .collect(Collectors.toMap(
                            t -> new TopicPartition(t.getKey().topic(), t.getKey().partition()),
                            t -> t.getValue().offset()));
        } finally {
            // Bug fix: previously the channel leaked when send/receive/parsing threw.
            channel.disconnect();
        }
    }

    /** Collects every partition assigned to any member of the group as a Scala Seq. */
    private Seq<TopicAndPartition> topicAndPartitionSeq(ConsumerGroupSummary consumerGroupSummary) {
        List<ConsumerSummary> consumerSummaries = consumerGroupSummary.getConsumerSummaries();
        List<TopicPartition> topicPartitions = consumerSummaries.stream().flatMap(t -> t.getTopicPartitions().stream()).collect(Collectors.toList());
        List<TopicAndPartition> topicAndPartitions = topicPartitions.stream().map(tp -> new TopicAndPartition(tp.topic(), tp.partition())).collect(Collectors.toList());
        return JavaConverters.asScalaIteratorConverter(topicAndPartitions.iterator()).asScala().toSeq();
    }

    /** Copies the fields of the Scala ConsumerSummary into the domain ConsumerSummary. */
    private ConsumerSummary toConsumerSummary(kafka.admin.AdminClient.ConsumerSummary consumerSummary) {
        List<TopicPartition> topicPartitions = toJavaList(consumerSummary.assignment());
        String summaryClientId = consumerSummary.clientId();
        String consumerId = consumerSummary.consumerId();
        String host = consumerSummary.host();
        return ConsumerSummary.builder()
                .topicPartitions(topicPartitions)
                .clientId(summaryClientId)
                .consumerId(consumerId)
                .host(host)
                .build();
    }

    private List<String> groupOverviewToId(List<GroupOverview> groupOverviews) {
        return groupOverviews.stream().map(GroupOverview::groupId).collect(Collectors.toList());
    }

    private <T> List<T> toJavaList(Seq<T> seq) {
        return JavaConversions.seqAsJavaList(seq);
    }

    private <K, V> Map<K, V> toJavaMap(scala.collection.Map<K, V> map) {
        return JavaConversions.mapAsJavaMap(map);
    }

    /** Closes both underlying clients; best-effort on each. */
    @Override
    public void close() {
        try {
            adminClient.close();
        } finally {
            // Bug fix: ensure the consumer closes even if the admin client close throws.
            kafkaConsumer.close();
        }
    }
}
